{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "g2KnYlEyrifF"
      },
      "source": [
        "This Colab shows how to load the provided `.npz` file with rank-$49$ factorizations of $\\boldsymbol{\\mathscr{T}}_4$ in standard arithmetic, and how to compute the invariants $\\mathscr{R}$ and $\\mathscr{K}$ in order to demonstrate that these factorizations are mutually nonequivalent. For more details, see Supplement B of the paper.\n",
        "\n",
        "- Copyright 2022 DeepMind Technologies Limited\n",
        "- All software is licensed under the Apache License, Version 2.0 (Apache 2.0); you may not use this file except in compliance with the Apache 2.0 license. You may obtain a copy of the Apache 2.0 license at: https://www.apache.org/licenses/LICENSE-2.0\n",
        "- All other materials are licensed under the Creative Commons Attribution 4.0 International License (CC-BY).  You may obtain a copy of the CC-BY license at: https://creativecommons.org/licenses/by/4.0/legalcode\n",
        "- Unless required by applicable law or agreed to in writing, all software and materials distributed here under the Apache 2.0 or CC-BY licenses are distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the licenses for the specific language governing permissions and limitations under those licenses.\n",
        "- This is not an official Google product."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LAx2AaeInFsQ"
      },
      "outputs": [],
      "source": [
        "from typing import Hashable, List, Optional, Tuple\n",
        "\n",
        "from google.colab import files\n",
        "import numpy as np"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CvbKWr5cnb54"
      },
      "outputs": [],
      "source": [
        "# Upload the provided `.npz` file containing the factorizations.\n",
        "chosen_file = list(files.upload().keys())[0]\n",
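        "# Load via a context manager so that the underlying file handle is closed.\n",
        "with np.load(chosen_file) as data:\n",
        "  factorizations = data['factorizations']  # [num_factorizations, R, 3, S]\n",
        "print(factorizations.shape)"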
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rGAvoMN1YsEH"
      },
      "source": [
        "**Compute basic properties of the factorizations**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ONwOVPjMYwCg"
      },
      "outputs": [],
      "source": [
        "entries = list(np.unique(factorizations))\n",
        "print('Factors include {} as entries.'.format(', '.join(map(str, entries))))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "v_J98exRY9pR"
      },
      "outputs": [],
      "source": [
        "# Count the nonzero entries of each factorization, i.e. over its [R, 3, S] axes.\n",
        "nonzeros = np.count_nonzero(factorizations, axis=(-3, -2, -1))\n",
        "print('Nonzeros in a factorization range from {} to {}.'.format(\n",
        "    np.min(nonzeros), np.max(nonzeros)))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e4WxaaQ-ngil"
      },
      "source": [
        "**Code to compute the invariants $\\mathscr{R}$ and $\\mathscr{K}$**\n",
        "\n",
        "In code comments, `S` denotes the tensor size ($16$ for $\\boldsymbol{\\mathscr{T}}_4$, i.e. the number of entries of a $4 \\times 4$ matrix) and `R` denotes the rank ($49$ in the provided `.npz` file)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8-EXKPx9njIw"
      },
      "outputs": [],
      "source": [
        "def _matricize_factorization(factorization: np.ndarray) -\u003e np.ndarray:\n",
        "  \"\"\"Transforms [R, 3, S] `factorization` into [R, 3, sqrt S, sqrt S].\"\"\"\n",
        "  rank, _, tensor_size = factorization.shape\n",
        "  matrix_size = int(np.sqrt(tensor_size))\n",
        "  return factorization.reshape((rank, 3, matrix_size, matrix_size))\n",
        "\n",
        "\n",
        "def compute_invariant_r(factorization: np.ndarray) -\u003e Hashable:\n",
        "  \"\"\"Returns the matrix rank invariant R of `factorization`.\n",
        "\n",
        "  The matrix rank invariant of a factorization {U_r, V_r, W_r}_{r=1}^R is\n",
        "      { { rank(U_r), rank(V_r), rank(W_r) } }_{r=1}^R\n",
        "  where {.} denotes an unordered tuple, and rank(X_r) is the matrix rank of the\n",
        "  factor X_r when seen as a square matrix. See Supplementary Information of the\n",
        "  paper for more details.\n",
        "\n",
        "  Args:\n",
        "    factorization: [R, 3, S] array representing {U_r, V_r, W_r}_{r=1}^R\n",
        "  Returns:\n",
        "    The matrix rank invariant of `factorization`. The unordered tuples are\n",
        "    returned in a canonical form (sorted).\n",
        "  \"\"\"\n",
        "  matricized = _matricize_factorization(factorization)  # [R, 3, sqrt S, sqrt S]\n",
        "  ranks = np.linalg.matrix_rank(matricized)  # [R, 3]\n",
        "  return tuple(sorted(tuple(sorted(factor_ranks)) for factor_ranks in ranks))\n",
        "\n",
        "\n",
        "def _action_abc(a: np.ndarray, b: np.ndarray, c: np.ndarray,\n",
        "                matricized_factorization: np.ndarray) -\u003e np.ndarray:\n",
        "  \"\"\"Applies the (A, B, C) action to `matricized_factorization`.\"\"\"\n",
        "  inv_a = np.linalg.inv(a)  # [sqrt(S), sqrt(S)]\n",
        "  inv_b = np.linalg.inv(b)  # [sqrt(S), sqrt(S)]\n",
        "  inv_c = np.linalg.inv(c)  # [sqrt(S), sqrt(S)]\n",
        "  return np.array([[a @ u @ inv_b, b @ v @ inv_c, c @ w @ inv_a]\n",
        "                   for u, v, w in matricized_factorization])\n",
        "\n",
        "\n",
        "def _phi(matricized_factorization: np.ndarray) -\u003e Optional[np.ndarray]:\n",
        "  \"\"\"Canonicalizes the factorization to reveal an (I, I, I) factor.\n",
        "\n",
        "  This corresponds to the map Phi described in Supplement B of the paper.\n",
        "\n",
        "  Args:\n",
        "    matricized_factorization: [R, 3, sqrt S, sqrt S] array\n",
        "  Returns:\n",
        "    - `None` if the factorization does not have exactly one factor such that\n",
        "      U_r V_r W_r equals the identity matrix.\n",
        "    - Otherwise, an [R, 3, sqrt S, sqrt S] array representing the matricized\n",
        "      factorization after applying the (A, B, C) action that transforms the\n",
        "      factor (U_r, V_r, W_r) satisfying U_r V_r W_r = I into (I, I, I).\n",
        "  \"\"\"\n",
        "  matrix_size = matricized_factorization.shape[-1]\n",
        "\n",
        "  # Check if there is exactly one factor with U_r V_r W_r = I.\n",
        "  u, v, w = np.moveaxis(matricized_factorization, 1, 0)\n",
        "  is_uvw_identity = np.all(u @ v @ w == np.eye(matrix_size), axis=(1, 2))  # [R]\n",
        "  if np.sum(is_uvw_identity) != 1:\n",
        "    return None\n",
        "  identity_index = np.argmax(is_uvw_identity)  # []\n",
        "\n",
        "  # Apply (A, B, C) action that transforms this factor into (I, I, I).\n",
        "  u1, v1, w1 = matricized_factorization[identity_index]\n",
        "  matrix_a = np.eye(matrix_size)  # [sqrt(S), sqrt(S)]\n",
        "  matrix_c = np.linalg.inv(w1)  # [sqrt(S), sqrt(S)]\n",
        "  matrix_b = matrix_c @ np.linalg.inv(v1)  # [sqrt(S), sqrt(S)]\n",
        "  return _action_abc(matrix_a, matrix_b, matrix_c, matricized_factorization)\n",
        "\n",
        "\n",
        "def compute_invariant_k(factorization: np.ndarray) -\u003e Hashable:\n",
        "  \"\"\"Computes the invariant K of `factorization`, or returns `()`.\n",
        "\n",
        "  See Supplement B of the paper for a detailed description of this invariant,\n",
        "  and for a proof of its correctness.\n",
        "\n",
        "  Args:\n",
        "    factorization: [R, 3, S] array representing {U_r, V_r, W_r}_{r=1}^R\n",
        "  Returns:\n",
        "    The invariant K of `factorization`. For factorizations where this invariant\n",
        "    is not applicable, a constant dummy value is returned (so that all these\n",
        "    factorizations are considered indistinguishable using this invariant).\n",
        "  \"\"\"\n",
        "  matricized = _matricize_factorization(factorization)  # [R, 3, sqrt S, sqrt S]\n",
        "  canonicalized = _phi(matricized)  # [R, 3, sqrt S, sqrt S]\n",
        "  if canonicalized is None:\n",
        "    return ()  # Dummy value when the invariant is not applicable.\n",
        "  return frozenset(\n",
        "      # `np.poly` computes the coefficients of the characteristic polynomial\n",
        "      # of the Kronecker product u x v x w.\n",
        "      tuple(np.round(np.poly(np.kron(np.kron(u, v), w))).astype(np.int32))\n",
        "      for u, v, w in canonicalized)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FmBGGUpRpwmh"
      },
      "source": [
        "**Run invariant computations on the loaded factorizations**\n",
        "\n",
        "This can take up to 10 minutes to complete. Progress is printed after every 500 factorizations and once more at the end."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ckWpaIA5a3CA"
      },
      "outputs": [],
      "source": [
        "invariants: List[Tuple[Hashable, Hashable]] = []\n",
        "for fi, current_factorization in enumerate(factorizations):\n",
        "  invariant_r = compute_invariant_r(current_factorization)\n",
        "  invariant_k = compute_invariant_k(current_factorization)\n",
        "  invariants.append((invariant_r, invariant_k))\n",
        "  if (fi + 1) % 500 == 0 or (fi + 1) == len(factorizations):\n",
        "    print('After processing {}/{} factorizations: unique invariant values '\n",
        "          '(number of nonequivalent factorizations): {}'.format(\n",
        "              fi + 1, len(factorizations), len(set(invariants))))"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "inspect_factorizations_notebook.ipynb",
      "private_outputs": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
